#include "MOMDP.h"
#include "SparseVector.h"
#include "SparseMatrix.h"
#include "Parser.h"
#include "POMDP.h"
#include "GlobalResource.h"

#include <string>
#include <stdlib.h>
#include <sstream>
#include <fstream>

using namespace std;
using namespace momdp;

/**
 * Prints the command-line synopsis followed by one example invocation
 * to standard output.
 *
 * @param cmdName  program name to show in the text (typically argv[0]).
 */
void print_usage(const char* cmdName) 
{
    const char* const program = cmdName;
    std::cout << "Usage: " << program << " POMDPModelFileName\n " << std::endl
              << "Example:" << std::endl
              << "  " << program << " Hallway.pomdp" << std::endl;
}

void writeSparseVector(ostream& out, SparseVector& sv, int numStates) {

    int index = 0;
    for(vector<SparseVector_Entry>::const_iterator bi = sv.data.begin(); bi != sv.data.end(); bi++)
    {
	if ( index == (signed) bi->index )
	{
	    //non zero value
	    out << bi->value;
	}
	else
	{
	    //buffer up zero values
	    for ( ;index < (signed) bi->index; index ++ )
	    {
		out << "0 ";
	    }
	    out << bi->value;
	}

	if ( index != numStates-1 )
	{
	    out << " ";
	}
	index ++;
    }

    if ( index < numStates )
    {
	for ( ;index < numStates-1; index ++ )
	{
	    out << "0 ";
	}
	out << "0";
    }
}

//writeout SparseMatrix to POMDPX entries
void writeSparseMatrix(ostream& out, SparseMatrix& sm, SparseMatrix& smtr, int action, char type, int numStates) {
    
    //check sparsity of matrix
    if(sm.data.size() < (sm.size1_ * sm.size2_)/20)
    {
	vector<SparseVector_Entry>::const_iterator  di;
	FOR (c, sm.size2_) {
        SparseCol col = sm.col(c);
	    for (di = col.begin(); di != col.end(); di++) {

		out << "\n<Entry>\n<Instance>";
		out << "a" << action << " " << "s" << di->index << " " << type << c;
		out << "</Instance>\n<ProbTable>" << di->value << "</ProbTable></Entry>";
	    }
	}
    }
    else{
	//use transposed matrix for dumping dense matrix
	
	vector<SparseVector_Entry>::const_iterator  di, col_end;
	out << "\n<Entry>\n<Instance>";
	out << "a" << action << " - - </Instance>\n<ProbTable>" ;

	FOR (c, smtr.size2_) {
	    int index=0;
            SparseCol col = smtr.col(c);
	    for (di = col.begin(); di != col.end(); di++) {

		if ( index == (signed) di->index )
		{
		    //non zero value
		    out << di->value;
		}
		else
		{
		    //buffer up zero values
		    for ( ;index < (signed) di->index; index ++ )
		    {
			out << "0 ";
		    }
		    out << di->value;
		}

		if ( index != numStates-1 )
		{
		    out << " ";
		}
		index ++;
	    }

	    //ensure dense matrix is of correct dimension
	    if ( index < numStates )
	    {
		for ( ;index < numStates-1; index ++ )
		{
		    out << "0 ";
		}
		out << "0";
	    }
	    out << endl;
	}
	
	out << "</ProbTable></Entry>";
    }
}

// Writes one POMDPX reward <Entry> per element stored in sm.
// The column index c is written as action "a<c>", and each stored element's
// row index as state "s<row>". In other words, sm is read as a
// (state x action) table.
// Cells not stored in the sparse matrix produce no <Entry>. Presumably those
// are zero rewards; confirm that the POMDPX reader defaults missing
// instances to 0.
//
// @param out  destination stream
// @param sm   reward matrix to write
void writeSparseMatrixReward(ostream& out, SparseMatrix& sm)
{
    FOR (c, sm.size2_) {
        SparseCol col = sm.col(c);
        for (vector<SparseVector_Entry>::const_iterator di = col.begin(); di != col.end(); ++di) {
            out << "\n<Entry>\n<Instance>";
            out << "a" << c << " s" << di->index;
            out << "</Instance>\n<ValueTable>" << di->value << "</ValueTable></Entry>";
        }
    }
}

/**
 * Serialises a parsed POMDP as a POMDPX document on `pomdpxfile`.
 *
 * The document declares:
 *  - one state variable (state_0 / state_1) with problem->numStates values;
 *  - the observation variable obs_sensor;
 *  - the action variable action_agent with problem->T.size() values;
 *  - the reward variable reward_agent.
 *
 * It then writes, in order: the initial belief, the transition table from
 * T/Ttr, the observation table from O/Otr, and the reward table.
 *
 * @param problem     parsed model. Its per-action matrices are handed to the
 *                    writers by reference rather than copied.
 * @param pomdpxfile  destination stream; it is not flushed or closed here.
 */
void convertToPomdpx(POMDP* problem, ofstream& pomdpxfile){

    pomdpxfile << "<?xml version='1.0' encoding='ISO-8859-1'?>\n \
	\n\
	\n\
	<pomdpx version='0.1' id='autogenerated' xmlns:xsi='http://www.w3.org/2001/XMLSchema-instance' xsi:noNamespaceSchemaLocation='pomdpx.xsd'>\n\
	\n\
	\n<Description>This is an auto-generated POMDPX file</Description>\
	\n<Discount>" << problem->getDiscount() << "</Discount>		  \
	\n\
	\n<Variable>\
	\n\
	\n<StateVar vnamePrev=\"state_0\" vnameCurr=\"state_1\" fullyObs=\"false\">\
	\n<NumValues>" << problem->numStates << "</NumValues>\
	\n</StateVar>\
	\n\
	\n<ObsVar vname=\"obs_sensor\">\
	\n<NumValues>" << problem->numObservations << "</NumValues>\
	\n</ObsVar>\
	\n\
	\n<ActionVar vname=\"action_agent\">\
	\n<NumValues>" << problem->T.size() << "</NumValues>\
	\n</ActionVar>\
	\n\
	\n<RewardVar vname=\"reward_agent\"/>\
	\n</Variable>\
	\n\
	\n\
	\n<InitialStateBelief>\
	\n<CondProb>\
	\n<Var>state_0</Var>\
	\n<Parent>null</Parent>\
	\n<Parameter type = \"TBL\">\
	\n<Entry>\
	\n<Instance>-</Instance>\
	\n<ProbTable>";
    writeSparseVector(pomdpxfile, problem->getInitialBelief(), problem->numStates);

    pomdpxfile << "</ProbTable>\n \
	\n</Entry>\
	\n</Parameter>\
	\n</CondProb>\
	\n</InitialStateBelief>";

    pomdpxfile << "<StateTransitionFunction>\n \
	\n<CondProb>\
	\n<Var>state_1</Var>\
	\n<Parent>action_agent state_0</Parent>\
	\n<Parameter type = \"TBL\">";

    for (unsigned int i=0; i < problem->T.size(); i++) {
	// Bind by reference: the writer only reads the matrices, so copying two
	// whole matrices per action is pure overhead.
	SparseMatrix& sm = problem->T[i];
	SparseMatrix& smtr = problem->Ttr[i];
	writeSparseMatrix(pomdpxfile, sm, smtr, i, 's', problem->numStates);
    }

    pomdpxfile << "\
	\n</Parameter>\
	\n</CondProb>\
	\n</StateTransitionFunction>\n\n";

    pomdpxfile << "<ObsFunction>\n \
	\n<CondProb>\
	\n<Var>obs_sensor</Var>\
	\n<Parent>action_agent state_1</Parent>\
	\n<Parameter type = \"TBL\">";

    for (unsigned int i=0; i < problem->O.size(); i++) {
	// Same as above: pass the model's matrices directly instead of copies.
	SparseMatrix& sm = problem->O[i];
	SparseMatrix& smtr = problem->Otr[i];
	writeSparseMatrix(pomdpxfile, sm, smtr, i, 'o', problem->numObservations);
    }

    pomdpxfile << "\
	\n</Parameter>\
	\n</CondProb>\
	\n</ObsFunction>\n\n";

    pomdpxfile << "<RewardFunction>\n \
	\n<Func>\
	\n<Var>reward_agent</Var>\
	\n<Parent>action_agent state_0</Parent>\
	\n<Parameter type = \"TBL\">";

    // Kept as a local object: getRewardMatrix() may return by value, and the
    // writer takes a non-const reference.
    SparseMatrix sm = problem->getRewardMatrix();
    writeSparseMatrixReward(pomdpxfile, sm);
    pomdpxfile << "\
	\n</Parameter>\
	\n</Func>\
	\n</RewardFunction>";

    pomdpxfile << "</pomdpx>";
}


/**
 * Entry point: parses the POMDP model named on the command line and writes
 * its POMDPX translation to "<input name>x" (e.g. Hallway.pomdp ->
 * Hallway.pomdpx).
 *
 * @return 0 on success; EXIT_FAILURE on bad arguments, when the output file
 *         cannot be opened or written, or when an exception is caught.
 */
int main(int argc, char **argv) 
{
    try
    {
        SolverParams* p = &GlobalResource::getInstance()->solverParams;
        bool parseCorrect = SolverParams::parseCommandLineOption(argc, argv, *p);
        if(!parseCorrect)
        {
            print_usage(p->cmdName);
            exit(EXIT_FAILURE);
        }

        // NOTE(review): parser and pomdpProblem are never freed. This is
        // harmless for a one-shot tool that exits right after conversion, but
        // ownership by POMDP/Parser is not visible here.
        Parser* parser = new Parser();
        POMDP* pomdpProblem = parser->parse(p->problemName, p->useFastParser);

        // Build the output name locally instead of append()-ing to the shared
        // solver parameters, which would mutate global state.
        const string outFileName = p->problemName + "x";
        ofstream pomdpxFile(outFileName.c_str());
        if (!pomdpxFile)
        {
            cout << "Cannot open output file " << outFileName << " for writing." << endl;
            return EXIT_FAILURE;
        }

        convertToPomdpx(pomdpProblem, pomdpxFile);
        pomdpxFile.flush();
        pomdpxFile.close();

        // close() sets failbit if flushing or closing failed. Any earlier
        // write error also leaves the stream failed.
        if (pomdpxFile.fail())
        {
            cout << "Error while writing output file " << outFileName << "." << endl;
            return EXIT_FAILURE;
        }
    }
    catch(const bad_alloc &)
    {
        if(GlobalResource::getInstance()->solverParams.memoryLimit == 0)
        {
            cout << "Memory allocation failed. Exit." << endl;
        }
        else
        {
            cout << "Memory limit reached. Please try increase memory limit" << endl;
        }
        return EXIT_FAILURE;
    }
    catch(const exception &e)
    {
        cout << "Exception: " << e.what() << endl ;
        return EXIT_FAILURE;
    }

    return 0;
}

